# -*- coding: utf-8 -*-
"""
Created on Thu Oct 31 17:45:02 2024

@author: Yunhong Che
"""

from beep import structure
import os
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from scipy.interpolate import interp1d
import matplotlib.pyplot as plt
import random
import seaborn as sns
from matplotlib import rcParams
import pandas as pd
from sklearn import metrics
from matplotlib.colors import Normalize
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import matplotlib.colors as mcolors
import time
from pyswarm import pso
from scipy.stats import gaussian_kde
from scipy.signal import savgol_filter
from sklearn.model_selection import KFold
from scipy.stats import linregress
from matplotlib.lines import Line2D
from scipy.optimize import minimize
from scipy.optimize import differential_evolution
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from xgboost import XGBRegressor
from sklearn.model_selection import GridSearchCV
import time
from bayes_opt import BayesianOptimization
from sko.GA import GA
from scipy.spatial import distance
from concurrent.futures import ThreadPoolExecutor
from joblib import Parallel, delayed
import cma
from scipy.optimize import approx_fprime
from numpy.linalg import matrix_rank
from matplotlib.colors import LinearSegmentedColormap
import warnings
warnings.filterwarnings("ignore")

def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    Also sets ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Global matplotlib font used by every figure in this script.
font_properties = {'family': 'Times New Roman', 'size': 12}
rcParams['font.family'] = font_properties['family']
rcParams['font.size'] = font_properties['size']

#%%
nominal_capacity = 4.84

# Electrode tables from the "Cover5" (C/5) files.
OCPn_data = pd.read_csv('anode_SiO_Gr_discharge_Cover5_smoothed_dvdq_JS.csv')
OCPp_data = pd.read_csv('cathode_NCA_discharge_Cover5_smoothed_dvdq_JS.csv')

OCPn_SOC, OCPn_V = OCPn_data['SOC_linspace'].values, OCPn_data['Voltage'].values
OCPp_SOC = OCPp_data['SOC_linspace'].values
# The cathode voltage column is paired with SOC_linspace in reversed order
# (presumably the file is stored against the opposite SOC direction - confirm).
OCPp_V = np.flip(OCPp_data['Voltage'].values).copy()

# Cubic interpolants of the C/5 tables; queries outside the tabulated SOC
# range are extrapolated instead of raising.
OCP_p = interp1d(OCPp_SOC, OCPp_V, kind='cubic', bounds_error=False, fill_value='extrapolate')
OCP_n = interp1d(OCPn_SOC, OCPn_V, kind='cubic', bounds_error=False, fill_value='extrapolate')

# Electrode tables from the "Cover40" (C/40) files, processed the same way.
OCPn_data_40 = pd.read_csv('anode_SiO_Gr_discharge_Cover40_smooth_JS.csv')
OCPp_data_40 = pd.read_csv('cathode_NCA_discharge_Cover40_smooth_JS.csv')

OCPn_SOC_40, OCPn_V_40 = OCPn_data_40['SOC_linspace'].values, OCPn_data_40['Voltage'].values
OCPp_SOC_40 = OCPp_data_40['SOC_linspace'].values
OCPp_V_40 = np.flip(OCPp_data_40['Voltage'].values).copy()

OCP_p_40 = interp1d(OCPp_SOC_40, OCPp_V_40, kind='cubic', bounds_error=False, fill_value='extrapolate')
OCP_n_40 = interp1d(OCPn_SOC_40, OCPn_V_40, kind='cubic', bounds_error=False, fill_value='extrapolate')

#%% only needed for the first run to extract data

def _first_index_where(mask, start=2):
    """Return the first index >= ``start`` where ``mask`` is True.

    When no entry matches, the last index (``len(mask) - 1``) is returned,
    i.e. the position a linear scan reaches when it never finds a match.
    """
    hits = np.flatnonzero(mask[start:])
    if hits.size:
        return int(hits[0]) + start
    return len(mask) - 1


def _extract_discharge_curve(cycle_data, nominal_capacity, num_points, min_capacity=None):
    """Cut the discharge segment out of one cycle and resample it.

    Parameters
    ----------
    cycle_data : table with 'voltage', 'current' and 'discharge_capacity' columns.
    nominal_capacity : divisor used to normalise the capacity axis.
    num_points : number of evenly spaced samples in the resampled curve.
    min_capacity : if given, the segment is rejected when the largest raw
        'discharge_capacity' inside it is below this value.

    Returns
    -------
    (voltage, capacity, current) or None
        ``voltage`` is the segment voltage divided by 4.2, ``capacity`` is the
        segment capacity relative to its first sample divided by
        ``nominal_capacity`` (both resampled to ``num_points``), and
        ``current`` is the raw current 50 samples after the segment start.
        None is returned when no usable segment exists.
    """
    V = np.array(cycle_data['voltage'])
    I = np.array(cycle_data['current'])
    Q = np.array(cycle_data['discharge_capacity'])

    # The scans below start at index 2, so shorter records cannot be segmented.
    if len(Q) <= 2:
        return None

    # Segment start: first sample (index >= 2) whose capacity increased while
    # the current is negative.
    discharging = np.zeros(len(Q), dtype=bool)
    discharging[1:] = (Q[1:] > Q[:-1]) & (I[1:] < 0)
    start_index = _first_index_where(discharging)

    # Segment end: first sample (index >= 2) whose voltage is <= 2.7.
    end_index = _first_index_where(V <= 2.7)

    # An end at or before the start leaves nothing to interpolate.
    if end_index <= start_index:
        return None
    if min_capacity is not None and np.max(Q[start_index:end_index]) < min_capacity:
        return None

    V_seg = V[start_index:end_index + 1] / 4.2
    Q_seg = Q[start_index:end_index + 1] / nominal_capacity - Q[start_index] / nominal_capacity

    # Resample both signals on a common, evenly spaced pseudo-time axis.
    x_new = np.linspace(0, 1, num_points)
    voltage = interp1d(np.linspace(0, 1, len(V_seg)), V_seg, kind='linear')(x_new)
    capacity = interp1d(np.linspace(0, 1, len(Q_seg)), Q_seg, kind='linear')(x_new)

    # Current is sampled 50 points after the segment start; the index is
    # clamped so a short record cannot run off the end of the array.
    current = I[min(start_index + 50, len(I) - 1)]
    return voltage, capacity, current


def read_resval_data(file_name, idx_seed,  predic_c_rate, num_points=1000):
    """Load one cell file and build a (measured curve, C/40 curve) sample pair.

    Parameters
    ----------
    file_name : path handed to ``structure.MaccorDatapath.from_file``.
    idx_seed : value passed to ``set_random_seed`` before extraction.
    predic_c_rate : cycle-type label of the curve used as model input
        (one of the labels in ``all_cycle_types`` below).
    num_points : number of samples each curve is resampled to.

    Returns
    -------
    inputs, outputs : numpy arrays
        With a usable sample, ``inputs`` has shape (1, num_points, 3) holding
        [voltage / 4.2, normalised capacity, current / nominal capacity] for
        the ``predic_c_rate`` cycle, and ``outputs`` has shape
        (1, num_points, 2) holding [voltage / 4.2, normalised capacity] for
        the "C/40_Cycle" cycle. Both are empty arrays when either curve is
        missing or unusable, so callers can test ``len(inputs) == 0``.
    """
    inputs = []
    outputs = []
    nominal_capacity = 4.84
    datapath = structure.MaccorDatapath.from_file(file_name)
    # Position in this list == cycle_index value in the raw data.
    all_cycle_types = [
            "start_discharge",
            "C/80_Cycle",
            "GITT",
            "C/40_Cycle",
            "0.05A_Cycle_mistake",
            "C/10_Cycle",
            "C/7_cycle",
            "C/5_Cycle",
            "1C_Cycle",
            "2C_Cycle",
            "charge_for_storage",
    ]
    datapath.raw_data["cycle_type"] = datapath.raw_data["cycle_index"].apply(lambda x: all_cycle_types[x])

    data = datapath.raw_data

    set_random_seed(idx_seed)

    # Model input: the curve at the requested rate. It must reach a raw
    # discharge capacity of at least 2 to be accepted.
    measured = _extract_discharge_curve(
        data[data["cycle_type"] == predic_c_rate],
        nominal_capacity, num_points, min_capacity=2,
    )
    if measured is not None:
        # Model target: the C/40 curve of the same cell. Each curve gets its
        # own segment indices, so a missing C/40 cycle skips the cell instead
        # of reusing indices from the curve above.
        reference = _extract_discharge_curve(
            data[data["cycle_type"] == "C/40_Cycle"],
            nominal_capacity, num_points,
        )
        if reference is not None:
            voltage, capacity, current = measured
            rate_feature = current * np.ones(num_points) / nominal_capacity
            inputs.append(np.stack([voltage, capacity, rate_feature], axis=-1))
            outputs.append(np.stack([reference[0], reference[1]], axis=-1))

    # Convert input and output lists to NumPy arrays
    inputs = np.array(inputs)    # (num_samples, num_points, 3)
    outputs = np.array(outputs)  # (num_samples, num_points, 2)

    return inputs, outputs


# Cycle types to extract; each becomes one contiguous group of rows in
# all_inputs / all_outputs, in this order.
predict_rate_list=["C/40_Cycle",
                    "C/5_Cycle",
                    ]

predict_ocv_all_diff_rate = []
Cp_all_diff_rate = []
Cn_all_diff_rate = []
x0_all_diff_rate = []
y0_all_diff_rate = []
shap_all_Q_diff_rate = []
shap_all_V_diff_rate = []

calculated_ocv_all_diff_rate = []
true_ocv_all_diff_rate = []
C_rates_all_diff_rate = []
measure_all_cap_rate = []
all_inputs, all_outputs = [], []
all_cells = []

nominal_capacity = 4.84

# Build the path with os.path.join so it resolves on any OS, and sort the
# listing: os.listdir returns entries in arbitrary order, which would make
# the cell order (and the per-cell seed index) platform dependent.
folder_loc = os.path.abspath(os.path.join('..', 'data_code', 'ResValData'))
file_list = sorted(f for f in os.listdir(folder_loc) if f.startswith('ResVal'))

for predic_c_rate in predict_rate_list:
    print(predic_c_rate)

    inputs_cell, outputs_cell = [], []

    for i, battery in enumerate(file_list):
        file_name = os.path.join(folder_loc, battery)
        cap = []  # not used in this section; kept in case later cells read it - TODO confirm and remove
        input_data, output_data = read_resval_data(file_name, i, predic_c_rate)
        if len(input_data) == 0:
            # read_resval_data returns empty arrays when the cell has no usable curve.
            print('skip')
            continue
        inputs_cell.append(input_data)
        outputs_cell.append(output_data)
        all_cells.append((battery, predic_c_rate))

    if not inputs_cell:
        raise ValueError(f"No usable cell data for {predic_c_rate} in {folder_loc}")

    inputs_cell = np.concatenate(inputs_cell, axis=0)
    outputs_cell = np.concatenate(outputs_cell, axis=0)

    all_inputs.append(inputs_cell)
    all_outputs.append(outputs_cell)

# One row per (cell, rate) pair, aligned with all_cells.
all_inputs = np.concatenate(all_inputs, axis=0)
all_outputs = np.concatenate(all_outputs, axis=0)


#%% OCV fit using mechanistic model and show the differences


## Objective function shared by all optimizers (PSO, BO, DE, GA, CMA-ES)
# Loss terms that can be combined in ``object_loss`` (joined with '_').
_LOSS_TERMS = ('eucl', 'mse', 'dvf')


def _parse_loss_terms(object_loss):
    """Split an ``object_loss`` name into its set of terms.

    Raises ValueError("Unknown loss") for a non-string, an empty name, a
    repeated term or a term outside ``_LOSS_TERMS``.
    """
    terms = object_loss.split('_') if isinstance(object_loss, str) else []
    if (not terms
            or len(set(terms)) != len(terms)
            or any(term not in _LOSS_TERMS for term in terms)):
        raise ValueError("Unknown loss")
    return frozenset(terms)


def objective_function(params, c_rate, measure_Q, measure_V, dof, object_loss):
    """Loss between a measured voltage curve and the half-cell model curve.

    The model voltage is ``OCP_p(y0 + Q / Cp) - OCP_n(x0 - Q / Cn)``, using the
    C/40 or C/5 electrode interpolants depending on ``c_rate``.

    Parameters
    ----------
    params : sequence of fit parameters, interpreted according to ``dof``:
        4 -> (Cp, Cn, x0, y0); 3 -> (Cp, Cn, x0) with y0 = 0;
        2 -> (NP_ratio, x0) with Cp = 1, Cn = NP_ratio, y0 = 0.
    c_rate : 'C/40_Cycle' or 'C/5_Cycle'.
    measure_Q, measure_V : 1-D arrays of the measured capacity and voltage.
    dof : 2, 3 or 4.
    object_loss : one or more of 'eucl', 'mse', 'dvf' joined by '_', in any
        order (e.g. 'eucl', 'eucl_mse', 'dvf_eucl', 'eucl_mse_dvf').
        'eucl' is the mean distance from each measured (Q, V) point to the
        nearest fitted (Q, V) point, 'mse' the mean squared voltage error,
        and 'dvf' the mean squared error between the dV/dQ curves.

    Returns
    -------
    Sum of the selected loss terms plus ``0.01 * sum(param ** 2)``.

    Raises
    ------
    ValueError
        For an unsupported ``dof``, ``c_rate`` or ``object_loss``.
    """
    if dof == 4:
        Cp, Cn, x0, y0 = params
        regularization = 0.01 * (Cp**2 + Cn**2 + x0**2 + y0**2)
    elif dof == 3:
        Cp, Cn, x0 = params
        y0 = 0
        regularization = 0.01 * (Cp**2 + Cn**2 + x0**2)
    elif dof == 2:
        NP_ratio, x0 = params
        Cp = 1  # nominal one
        y0 = 0
        Cn = Cp * NP_ratio
        regularization = 0.01 * (NP_ratio**2 + x0**2)
    else:
        raise ValueError("DOF must be 2, 3, or 4")

    # Electrode SOC as a function of the measured capacity.
    SOC_p = y0 + measure_Q / Cp
    SOC_n = x0 - measure_Q / Cn

    if c_rate == 'C/40_Cycle':
        Up = OCP_p_40(SOC_p)
        Un = OCP_n_40(SOC_n)
    elif c_rate == 'C/5_Cycle':
        Up = OCP_p(SOC_p)
        Un = OCP_n(SOC_n)
    else:
        raise ValueError("Unsupported c_rate type.")

    fitted_Voc = Up - Un

    terms = _parse_loss_terms(object_loss)

    # Terms are always accumulated in the order eucl -> mse -> dvf so the
    # result does not depend on how the name was spelled.
    total_loss = 0.0
    if 'eucl' in terms:
        measured_curve = np.vstack((measure_Q, measure_V)).T  # (Q, V) measure
        fitted_curve = np.vstack((measure_Q, fitted_Voc)).T   # (Q, V) fit
        nearest = distance.cdist(measured_curve, fitted_curve, "euclidean").min(axis=1)
        total_loss = total_loss + nearest.mean()
    if 'mse' in terms:
        total_loss = total_loss + np.mean((fitted_Voc - measure_V) ** 2)
    if 'dvf' in terms:
        dv_fit = np.gradient(fitted_Voc, measure_Q)
        dv_meas = np.gradient(measure_V, measure_Q)
        total_loss = total_loss + np.mean((dv_fit - dv_meas) ** 2)

    return total_loss + regularization


def get_bounds_by_dof(dof):
    """Return ``(lower_bounds, upper_bounds, names)`` for a 2-, 3- or 4-DOF fit.

    The three lists are freshly built on every call and are ordered the way
    the optimizers pass parameters to the objective. Raises ValueError when
    ``dof`` is not 2, 3 or 4.
    """
    if dof not in (2, 3, 4):
        raise ValueError("DOF must be 2, 3 or 4")

    wide = (0.2, 1.2)   # range used for Cp, Cn and NP_ratio
    unit = (0, 1.0)     # range used for x0, y0 and NP_offset
    layout = {
        4: (('Cp', wide), ('Cn', wide), ('x0', unit), ('y0', unit)),
        3: (('Cp', wide), ('Cn', wide), ('NP_offset', unit)),
        2: (('NP_ratio', wide), ('NP_offset', unit)),
    }[int(dof)]

    names = [name for name, _ in layout]
    lb = [limits[0] for _, limits in layout]
    ub = [limits[1] for _, limits in layout]
    return lb, ub, names

def run_pso(trial, Q, V, rate, dof, object_loss):
    """Run one particle-swarm fit of ``objective_function``.

    NumPy's global RNG is seeded with ``trial`` first. The search uses the
    box from ``get_bounds_by_dof(dof)`` with a swarm of 20 and at most 200
    iterations. Returns ``(best_params, best_loss)``.
    """
    lower, upper, _names = get_bounds_by_dof(dof)
    np.random.seed(trial)
    fit_args = (rate, Q, V, dof, object_loss)
    swarm_best, swarm_best_loss = pso(
        objective_function,
        lower,
        upper,
        args=fit_args,
        swarmsize=20,
        maxiter=200,
    )
    return swarm_best, swarm_best_loss


def run_bo(trial, Q, V, rate, dof, object_loss):
    """Run one Bayesian-optimization fit of ``objective_function``.

    ``BayesianOptimization`` maximizes, so it is given the negated loss.
    ``trial`` is used as its ``random_state``; 20 initial points are followed
    by 200 iterations. Returns ``(best_params, best_loss)`` with the
    parameters ordered as in ``get_bounds_by_dof(dof)`` and the loss sign
    restored.
    """
    lower, upper, names = get_bounds_by_dof(dof)
    pbounds = dict(zip(names, zip(lower, upper)))

    def negated_loss(**kwargs):
        # Keyword arguments arrive by name; restore the positional order.
        ordered = [kwargs[name] for name in names]
        return -objective_function(ordered, rate, Q, V, dof, object_loss)

    optimizer = BayesianOptimization(f=negated_loss, pbounds=pbounds, random_state=trial)
    optimizer.maximize(init_points=20, n_iter=200)

    best_by_name = optimizer.max['params']
    ordered_best = [best_by_name[name] for name in names]
    return ordered_best, -optimizer.max['target']


def run_de(trial, Q, V, rate, dof, object_loss):
    """Run one differential-evolution fit of ``objective_function``.

    ``trial`` seeds both NumPy's global RNG and the solver itself. The search
    uses ``popsize=20`` and ``maxiter=200`` inside the box from
    ``get_bounds_by_dof(dof)``. Returns ``(best_params, best_loss)``.
    """
    lower, upper, _names = get_bounds_by_dof(dof)
    bounds = list(zip(lower, upper))  # one (low, high) pair per parameter
    np.random.seed(trial)
    outcome = differential_evolution(
        objective_function,
        bounds,
        args=(rate, Q, V, dof, object_loss),
        popsize=20,
        maxiter=200,
        seed=trial,
    )
    return outcome.x, outcome.fun


def run_ga(trial, Q, V, rate, dof, object_loss):
    """Run one genetic-algorithm fit of ``objective_function``.

    NumPy's global RNG is seeded with ``trial`` first. The search uses a
    population of 20 and 200 iterations inside the box from
    ``get_bounds_by_dof(dof)``.

    Returns ``(best_params, best_loss)`` where ``best_loss`` is a plain float,
    matching the scalar loss returned by the other ``run_*`` helpers.
    """
    lb, ub, _ = get_bounds_by_dof(dof)
    np.random.seed(trial)
    ga = GA(
        func=lambda p: objective_function(p, rate, Q, V, dof, object_loss),
        n_dim=len(lb), size_pop=20, max_iter=200,
        lb=lb, ub=ub
    )
    best_p, best_f = ga.run()
    # GA.run() may hand back the best objective as a length-1 array rather
    # than a scalar; take its first element so callers never rely on implicit
    # array-to-scalar conversion.
    best_loss = float(np.asarray(best_f).ravel()[0])
    return best_p, best_loss


def run_cmaes(trial, Q, V, rate, dof, object_loss):
    """Run one CMA-ES fit of ``objective_function``.

    The search starts at the centre of the box from ``get_bounds_by_dof(dof)``
    with a step size equal to the mean of 20 % of each parameter range, a
    population of 20 and at most 200 iterations.

    Returns ``(best_params, best_loss)`` as reported by ``es.result``.
    """
    lb, ub, _ = get_bounds_by_dof(dof)
    bound_arr = np.array([lb, ub])
    x0 = np.mean(bound_arr, axis=0)
    sigma0 = 0.2 * (bound_arr[1] - bound_arr[0])

    cma_bounds = [bound_arr[0].tolist(), bound_arr[1].tolist()]
    opts = {
        'bounds': cma_bounds,
        # pycma documents seed values 0 and None as "seed from the clock", so
        # passing ``trial`` directly made trial 0 non-reproducible. Shifting
        # by one keeps every trial on its own fixed, non-zero seed.
        'seed': trial + 1,
        'maxiter': 200,
        'popsize': 20,
    }
    es = cma.CMAEvolutionStrategy(x0.tolist(), sigma0.mean(), opts)
    es.optimize(lambda p: objective_function(p, rate, Q, V, dof, object_loss))
    return es.result.xbest, es.result.fbest


def optimize_cycle(cell_inputs, cell_outputs, cell_rate, cell_name, opt_func, num_trials, dof, object_loss):
    """Fit the half-cell model to one measured curve with repeated trials.

    Parameters
    ----------
    cell_inputs : 2-D array; column 0 is voltage / 4.2 and column 1 is the
        normalised capacity of the measured curve.
    cell_outputs : not used by this function (kept for the call signature).
    cell_rate : 'C/5_Cycle' or 'C/40_Cycle'; selects the electrode curves.
    cell_name : cell identifier, used only in error messages.
    opt_func : 'PSO', 'BO', 'DE', 'GA' or 'CMA-ES'.
    num_trials : number of independent trials; trial ``k`` is seeded with
        ``k`` and all trials run through one ``joblib.Parallel`` call.
    dof : 2, 3 or 4, see ``objective_function``.
    object_loss : loss name forwarded to ``objective_function``.

    Returns
    -------
    dict with the best trial's 'Cp_opt', 'Cn_opt', 'x0_opt', 'y0_opt', the
    model curve 'fitted_Voc' evaluated on the measured capacities, the last
    capacity sample as both 'cell_cap' and 'Cq', and 'fit_results', the raw
    ``(params, loss)`` pair of every trial.

    Raises
    ------
    ValueError
        For an unsupported ``opt_func``, ``dof`` or ``cell_rate``.
    RuntimeError
        If no trial produced a loss below infinity (e.g. every loss is NaN).
    """
    random.seed(123)
    np.random.seed(123)

    measure_Q = cell_inputs[:, 1]
    measure_V = cell_inputs[:, 0] * 4.2  # undo the 4.2 voltage scaling

    trial_func_map = {
        'PSO': run_pso,
        'BO': run_bo,
        'DE': run_de,
        'GA': run_ga,
        'CMA-ES': run_cmaes
    }

    if opt_func not in trial_func_map:
        raise ValueError(f"Unsupported optimization algorithm: {opt_func}")
    run_trial = trial_func_map[opt_func]

    results = Parallel(n_jobs=num_trials)(
        delayed(run_trial)(trial, measure_Q, measure_V, cell_rate, dof, object_loss)
        for trial in range(num_trials)
    )

    # Keep the first trial with the strictly smallest loss; NaN losses never
    # compare smaller and are therefore ignored.
    best_params = None
    best_fopt = float('inf')
    for optimized_params, fopt in results:
        if fopt < best_fopt:
            best_fopt = fopt
            best_params = optimized_params

    if best_params is None:
        raise RuntimeError(
            f"No {opt_func} trial returned a finite loss for cell {cell_name} "
            f"(rate {cell_rate}, DOF {dof}, loss {object_loss})"
        )

    best_params = [float(p) for p in best_params]
    best_fopt = float(best_fopt)

    # Map the optimizer's parameter vector back to (Cp, Cn, x0, y0).
    if dof == 4:
        Cp_opt, Cn_opt, x0_opt, y0_opt = best_params
    elif dof == 3:
        Cp_opt, Cn_opt, NP_offset = best_params
        y0_opt = 0
        x0_opt = NP_offset
    elif dof == 2:
        Cp_nominal = 1
        NP_ratio, NP_offset = best_params
        Cn_opt = Cp_nominal * NP_ratio
        Cp_opt = Cp_nominal
        y0_opt = 0
        x0_opt = NP_offset
    else:
        raise ValueError("DOF must be 2, 3, or 4")

    # Re-evaluate the model curve at the best parameters.
    SOC_p_fit = y0_opt + measure_Q / Cp_opt
    SOC_n_fit = x0_opt - measure_Q / Cn_opt

    if cell_rate == 'C/5_Cycle':
        Up_fit = OCP_p(SOC_p_fit)
        Un_fit = OCP_n(SOC_n_fit)
    elif cell_rate == 'C/40_Cycle':
        Up_fit = OCP_p_40(SOC_p_fit)
        Un_fit = OCP_n_40(SOC_n_fit)
    else:
        raise ValueError("Unsupported cell_rate")

    fitted_Voc = Up_fit - Un_fit

    return {
        'Cp_opt': Cp_opt,
        'Cn_opt': Cn_opt,
        'x0_opt': x0_opt,
        'y0_opt': y0_opt,
        'fitted_Voc': fitted_Voc,
        'cell_cap': cell_inputs[-1, 1],
        'Cq': measure_Q[-1],
        'fit_results': results
    }


# Figure defaults: inward ticks, thin black axes frame, 12 pt text.
plt.rcParams.update({
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'axes.edgecolor': 'black',
    'axes.linewidth': 0.8,
    'font.size': 12,
})


num_trials = 10
opt_funcs_options = ['PSO','DE','GA','CMA-ES','BO'] #'PSO','DE','GA','CMA-ES','BO'
dofs = [ 2,3,4]
object_losses = ['eucl'] # other options: 'mse','dvf','eucl_mse','eucl_dvf','mse_dvf','eucl_mse_dvf'

# Create the output folder before the (long) fitting runs so that np.savez
# cannot fail on a missing directory afterwards.
os.makedirs('saved_fittings', exist_ok=True)

# True for cells recorded at the same rate as the first cell. all_cells is
# filled rate by rate, so this marks the leading rate group without assuming
# a fixed number of cells per rate.
cell_rates = np.array([rate for _, rate in all_cells])
first_rate_mask = cell_rates == cell_rates[0]


def plot_param_subplot(index, data, title):
    """Scatter one fitted parameter against cell capacity in a 2x2 grid slot.

    Reads the module-level ``cap_list`` for the x axis. Cells of the first
    rate group are drawn in '#84c3b7', the remaining cells in '#e68b81'.
    Tick marks are hidden on all four sides.
    """
    plt.subplot(2, 2, index)
    plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
    capacities = np.asarray(cap_list)
    values = np.asarray(data)
    plt.plot(capacities[first_rate_mask], values[first_rate_mask], 'o', color='#84c3b7')
    plt.plot(capacities[~first_rate_mask], values[~first_rate_mask], 'o', color='#e68b81')
    plt.title(title)


for object_loss in object_losses:
    for opt_func_trial in opt_funcs_options:
        for dof in dofs:
            random.seed(123)
            np.random.seed(123)
            torch.manual_seed(123)
            print(f"Running: {opt_func_trial} with DOF={dof} using object loss {object_loss}")
            start_time = time.time()

            # One optimize_cycle call per (cell, rate) row, run in parallel.
            Results = Parallel(n_jobs=56)(
                delayed(optimize_cycle)(
                    all_inputs[i], all_outputs[i], all_cells[i][1], all_cells[i][0],
                    opt_func_trial, num_trials, dof, object_loss
                ) for i in range(len(all_outputs))
            )

            elapsed_time = time.time() - start_time
            print('Time cost [min]:', elapsed_time/60)

            # Collect the per-cell results into flat lists.
            Cp_list, Cn_list, x0_list, y0_list = [], [], [], []
            OCV_fit_list, Cq_list, cap_list = [], [], []
            all_fit_results = []

            for result in Results:
                Cp_list.append(result.get('Cp_opt', np.nan))
                Cn_list.append(result.get('Cn_opt', np.nan))
                x0_list.append(result.get('x0_opt', np.nan))
                y0_list.append(result.get('y0_opt', np.nan))
                OCV_fit_list.append(result['fitted_Voc'])
                Cq_list.append(result['Cq'])
                cap_list.append(result['cell_cap'])
                all_fit_results.append(result['fit_results'])

            # Pack every trial as [params..., loss] into one array of shape
            # (num_cells, num_trials, param_dim + 1).
            param_dim = len(all_fit_results[0][0][0])
            fit_results_np = np.zeros((len(all_fit_results), num_trials, param_dim + 1))
            for i, trials in enumerate(all_fit_results):
                for j, (params, fopt) in enumerate(trials):
                    fit_results_np[i, j, :-1] = params
                    fit_results_np[i, j, -1] = fopt

            # Quick-look figure: fitted parameters versus cell capacity.
            plt.figure(figsize=(20 / 2.54, 16 / 2.54), dpi=600)
            plt.ion()

            plot_param_subplot(1, Cp_list, 'Cp')
            plot_param_subplot(2, Cn_list, 'Cn')
            plot_param_subplot(3, x0_list, 'x0')
            plot_param_subplot(4, y0_list, 'y0')

            plt.tight_layout()
            plt.show()

            # Persist everything needed by the analysis cells further down.
            filename = f"saved_fittings/resval_extract_data_{opt_func_trial}_DOF{dof}_{object_loss}.npz"
            np.savez(
                filename,
                all_Cp_opt=np.array(Cp_list),
                all_Cn_opt=np.array(Cn_list),
                all_x0_opt=np.array(x0_list),
                all_y0_opt=np.array(y0_list),
                all_Cq=np.array(Cq_list),
                all_OCV_fit=np.array(OCV_fit_list, dtype=object),
                all_cell_cap=np.array(cap_list),
                all_cell_ocv=all_outputs[:len(OCV_fit_list), :, :],
                all_cell_vmea=all_inputs[:len(OCV_fit_list), :, :],
                all_cells=all_cells,
                time_consum=elapsed_time,
                all_fit_results=fit_results_np
            )
            print(f"✅ Saved: {filename}")

#%%
for object_loss in ['eucl']: #object_losses
    time_consum_all_optimization_by_dof = {dof: [] for dof in dofs}
    fitopt_all_optimization_by_dof = {dof: [] for dof in dofs}
    fitopt_all_min_optimization_by_dof = {dof: [] for dof in dofs}
    fitopt_all_std_optimization_by_dof = {dof: [] for dof in dofs}
    
    for dof in [2, 3, 4]:
        time_consum_all_optimization = []
        fitopt_all_optimization = []
        fitopt_all_min_optimization = []
        fitopt_all_var_optimization = []
        for opt_func_trial in opt_funcs_options:
            filename = f"saved_fittings/resval_extract_data_{opt_func_trial}_DOF{dof}_{object_loss}.npz"
            print(filename)
            norminal_c = 4.84
            data = np.load(filename, allow_pickle=True)
            all_Cq = data['all_Cq']
            all_Cp_opt = data['all_Cp_opt']*norminal_c
            all_Cn_opt = data['all_Cn_opt']*norminal_c
            all_x0_opt = data['all_x0_opt']
            all_y0_opt = data['all_y0_opt']
            all_OCV_fit = data['all_OCV_fit']
            all_cell_cap = data['all_cell_cap']
            all_cell_ocv = data['all_cell_ocv']
            all_cell_vmea = data['all_cell_vmea']
            all_cells = data['all_cells']
            all_fit_results = data['all_fit_results']
            time_consum_all_optimization.append(data['time_consum']/60)
            all_cell_Vreal = all_cell_ocv[:,:,0]*4.2
            all_cell_Qreal = all_cell_ocv[:,:,1]
            all_cell_Vm = all_cell_vmea[:,:,0]*4.2
            all_cell_Qm = all_cell_vmea[:,:,1]
            all_cell_Crate = all_cell_vmea[:,:,2]
            all_v_diff = all_cell_Vreal-all_OCV_fit
            all_q_diff = all_cell_Qreal*norminal_c-all_cell_Qm*norminal_c
            
            
            data_dict = {key: data[key] for key in data.files}
            # print("Keys in the file:", data.files)
            all_cells = data['all_cells']  
            cell_names = all_cells[:, 0]  
            rate_labels = all_cells[:, 1]  
            unique_labels = np.unique(rate_labels)
            
            split_data_dict = {label: {} for label in unique_labels}
            
            for key in data_dict:
                if key!='time_consum':
                    if data_dict[key].shape[0] == len(rate_labels):  
                        for label in unique_labels:
                            split_data_dict[label][key] = data_dict[key][rate_labels == label]
                    else:
                        for label in unique_labels:
                            split_data_dict[label][key] = data_dict[key]
            
            # for label in unique_labels:
            #     print(f"Data for rate {label}: keys -> {split_data_dict[label].keys()}")
        
            data_sets = []
            # ['C/40', 'C/5']
            assert len(unique_labels) == 2, "C/5 and C/40, need to change to suit more C-rates"
            
            label1, label2 = unique_labels  
            subset1, subset2 = split_data_dict[label1], split_data_dict[label2]
            
            data_sets = [
                (subset1['all_Cp_opt']*norminal_c, subset2['all_Cp_opt']*norminal_c, f'Cp {label1}', f'Cp {label2}'),
                (subset1['all_Cn_opt']*norminal_c, subset2['all_Cn_opt']*norminal_c, f'Cn {label1}', f'Cn {label2}'),
                (subset1['all_Cn_opt'] * subset1['all_x0_opt']*norminal_c + subset1['all_Cp_opt'] * subset1['all_y0_opt']*norminal_c, 
                 subset2['all_Cn_opt'] * subset2['all_x0_opt']*norminal_c + subset2['all_Cp_opt'] * subset2['all_y0_opt']*norminal_c, 
                 f'Cli {label1}', f'Cli {label2}')
            ]
            
            
            fit_coefficients = all_fit_results[:, :, -1]  
            min_coefficients = np.mean(fit_coefficients, axis=1)  
            fitopt_all_optimization.append([fit_coefficients[0:94].reshape(-1),fit_coefficients[94:].reshape(-1)])
            fitopt_all_min_optimization.append([min_coefficients[0:94],min_coefficients[94:]])
            
            fit_std_per_cell = np.std(fit_coefficients, axis=1)  # shape = (188, )
            fitopt_all_var_optimization.append([fit_std_per_cell[0:94], fit_std_per_cell[94:]])
        #
        time_consum_all_optimization_by_dof[dof] = time_consum_all_optimization
        fitopt_all_optimization_by_dof[dof] = fitopt_all_optimization
        fitopt_all_min_optimization_by_dof[dof] = fitopt_all_min_optimization
        fitopt_all_std_optimization_by_dof[dof] = fitopt_all_var_optimization
        
        set_random_seed(123)
        algorithms = ['PSO','DE','GA','CMA-ES','BO']
        # colors = ['red', 'blue', 'green', 'purple','c']
        colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61']
        markers = ['o', '^']  # o:(C/40), ^:(C/5)
        rate_labels = ["C/40", "C/5"]
        
        fig, ax = plt.subplots(figsize=(10/2.54, 6/2.54), dpi=600)
        x_positions = np.array([1, 2, 3, 4, 5])  
        offset = 0.2 
        for algo_idx in range(len(algorithms)):
            for rate_idx in range(2):
                data = fitopt_all_optimization[algo_idx][rate_idx]
                x = x_positions[algo_idx] + (-1)**rate_idx * offset
                bp = ax.boxplot([data], 
                               positions=[x], 
                               widths=0.1,
                               patch_artist=True,
                               showfliers=False)
                for box in bp['boxes']:
                    box.set_facecolor(colors[algo_idx])
                    box.set_alpha(0.4)
                x_jittered = np.random.normal(x, 0.03, size=len(data))
                ax.scatter(x_jittered, data,
                          color=colors[algo_idx],
                          marker=markers[rate_idx],
                          alpha=0.7,
                          s=15,
                          edgecolors='white',
                          linewidths=0.0,
                          label=f'{algorithms[algo_idx]} ({rate_labels[rate_idx]})')
        ax.set_xticks(x_positions)
        ax.set_xticklabels(algorithms)
        ax.set_ylabel("Optimization target value")
        ax.grid(axis='y', linestyle='--', alpha=0.4)
        
        legend_elements = [
            Line2D([0], [0], marker='o', color='gray', label='C/40',
                  markersize=3, linestyle='None'),
            Line2D([0], [0], marker='^', color='gray', label='C/5',
                  markersize=3, linestyle='None')
        ]
        ax.legend(handles=legend_elements,
                  # bbox_to_anchor=(1.05, 1),
                  loc='upper right',
                  handletextpad=-0.2, 
                  frameon=False)
        
        plt.tight_layout()
        plt.show()
        
        
        fig, ax = plt.subplots(figsize=(10/2.54, 6/2.54), dpi=600)
        
        bars = ax.bar(x_positions, 
                     time_consum_all_optimization,
                     width=0.4,
                     color=colors[:len(algorithms)],  
                     alpha=0.8)

        for bar in bars:
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                   f'{height:.1f}',
                   ha='center', va='bottom',
                   fontsize=8)
        
        # Formatting
        ax.set_ylabel('Time [min]')
        ax.set_xticks(x_positions)
        ax.set_xticklabels(algorithms)
        ax.set_ylim([0.8*min(time_consum_all_optimization),1.1*max(time_consum_all_optimization)])
        plt.tight_layout()
        plt.show()
        
        #%
        set_random_seed(123)
        fig, ax = plt.subplots(figsize=(10/2.54, 6/2.54), dpi=600)
        x_positions = np.array([1, 2, 3, 4, 5])  # 
        offset = 0.2  # 
        
        mean_points = {0: [], 1: []}  # 0: C/40, 1: C/5
        
        for algo_idx in range(len(algorithms)):
            for rate_idx in range(2):
                
                data = fitopt_all_min_optimization[algo_idx][rate_idx]
                x = x_positions[algo_idx] + (-1)**rate_idx * offset
                mean_val = np.mean(data)
                mean_points[rate_idx].append((x, mean_val))
                vp = ax.violinplot([data], 
                                  positions=[x],
                                  widths=0.3,
                                  showmeans=False,
                                  showmedians=True,
                                  showextrema=False)
                
                for pc in vp['bodies']:
                    pc.set_facecolor(colors[algo_idx])
                    pc.set_alpha(0.4)
                    pc.set_edgecolor(colors[algo_idx])
                
                vp['cmedians'].set_color(colors[algo_idx])
                vp['cmedians'].set_linewidth(1)
                
                x_jittered = np.random.normal(x, 0.03, size=len(data))
                ax.scatter(x_jittered, data,
                          color=colors[algo_idx],
                          marker=markers[rate_idx],
                          alpha=0.7,
                          s=10,
                          edgecolors='white',
                          linewidths=0.0,
                          label=f'{algorithms[algo_idx]} ({rate_labels[rate_idx]})')
        
                ax.scatter(x, mean_val, 
                         color='grey',
                         edgecolor=colors[algo_idx],
                         marker='s',
                         s=15,
                         zorder=4,
                         linewidth=1)
        
        for rate_idx in range(2):
            x_vals = [point[0] for point in mean_points[rate_idx]]
            y_vals = [point[1] for point in mean_points[rate_idx]]
            ax.plot(x_vals, y_vals, 
                   color='gray', 
                   linestyle='--', 
                   linewidth=1,
                   alpha=0.8,
                   zorder=1)
        
        ax.set_xticks(x_positions)
        ax.set_xticklabels(algorithms)
        ax.set_ylabel("Optimization target value")
        # ax.grid(axis='y', linestyle='--', alpha=0.3)
        
        legend_elements = [
            Line2D([0], [0], marker='o', color='gray', label='C/40',
                  markersize=3, linestyle='None'),
            Line2D([0], [0], marker='^', color='gray', label='C/5',
                  markersize=3, linestyle='None'),
            Line2D([0], [0], markersize=3, color='gray', marker='s', linestyle='--', label='Mean')
        ]
        ax.legend(handles=legend_elements,
                  loc='upper right',
                  bbox_to_anchor=(1.0, 1.05), 
                  handletextpad=0.1, 
                  labelspacing=0.05,
                  frameon=False)
        
        plt.tight_layout()
        plt.show()
        
        
        set_random_seed(123)
        fig, ax = plt.subplots(figsize=(10/2.54, 6/2.54), dpi=600)
        
        x_positions = np.array([1, 2, 3, 4, 5])  
        offset = 0.2  
        
        mean_points = {0: [], 1: []}  # 0: C/40, 1: C/5
        
        fit_std_all = fitopt_all_std_optimization_by_dof[dof]  
        
        for algo_idx in range(len(algorithms)):
            for rate_idx in range(2):
                data = fit_std_all[algo_idx][rate_idx]
                
                # if we calculate CI
                # data_std = fit_std_all[algo_idx][rate_idx]   # calculate CI
                # n_repeat = 10
                # t975_df9 = 2.262  # t.ppf(0.975, 9)
                # data = t975_df9 * data_std / np.sqrt(n_repeat) # 95% CI half-width
                
                x = x_positions[algo_idx] + (-1)**rate_idx * offset
                mean_val = np.mean(data)
                mean_points[rate_idx].append((x, mean_val))
        
                vp = ax.violinplot([data], 
                                  positions=[x],
                                  widths=0.3,
                                  showmeans=False,
                                  showmedians=True,
                                  showextrema=False)
        
                for pc in vp['bodies']:
                    pc.set_facecolor(colors[algo_idx])
                    pc.set_alpha(0.4)
                    pc.set_edgecolor(colors[algo_idx])
        
                vp['cmedians'].set_color(colors[algo_idx])
                vp['cmedians'].set_linewidth(1)
        
                x_jittered = np.random.normal(x, 0.03, size=len(data))
                ax.scatter(x_jittered, data,
                          color=colors[algo_idx],
                          marker=markers[rate_idx],
                          alpha=0.6,
                          s=10,
                          edgecolors='white',
                          linewidths=0.0)
        
                ax.scatter(x, mean_val, 
                         color='grey',
                         edgecolor=colors[algo_idx],
                         marker='s',
                         s=15,
                         zorder=4,
                         linewidth=1)
        
        for rate_idx in range(2):
            x_vals = [point[0] for point in mean_points[rate_idx]]
            y_vals = [point[1] for point in mean_points[rate_idx]]
            ax.plot(x_vals, y_vals, 
                   color='gray', 
                   linestyle='--', 
                   linewidth=1,
                   alpha=0.8,
                   zorder=1)
        
        ax.set_xticks(x_positions)
        ax.set_xticklabels(algorithms)
        ax.set_ylabel("Fitting std per cell")
        # ax.set_ylabel("Per-cell 95% CI half-width")
        
        legend_elements = [
            Line2D([0], [0], marker='o', color='gray', label='C/40',
                  markersize=3, linestyle='None'),
            Line2D([0], [0], marker='^', color='gray', label='C/5',
                  markersize=3, linestyle='None'),
            Line2D([0], [0], markersize=3, color='gray', marker='s', linestyle='--', label='Mean')
        ]
        ax.legend(handles=legend_elements,
                  loc='upper right',
                  bbox_to_anchor=(1.0, 1.05), 
                  handletextpad=0.1, 
                  labelspacing=0.05,
                  frameon=False)
        
        plt.tight_layout()
        plt.show()
    
    
    time_matrix = np.zeros((len(dofs), len(algorithms)))
    fit_error_matrix_C40 = np.zeros((len(dofs), len(algorithms)))
    fit_error_matrix_C5 = np.zeros((len(dofs), len(algorithms)))
    
    for i, dof in enumerate(dofs):
        for j, algo in enumerate(algorithms):
            time_matrix[i, j] = time_consum_all_optimization_by_dof[dof][j]
    
            error_data_C40 = fitopt_all_min_optimization_by_dof[dof][j][0]
            fit_error_matrix_C40[i, j] = np.mean(error_data_C40)
    
            error_data_C5 = fitopt_all_min_optimization_by_dof[dof][j][1]
            fit_error_matrix_C5[i, j] = np.mean(error_data_C5)
    
    colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']
    custom_cmap = LinearSegmentedColormap.from_list("colors", ['#DB6C6E', '#F4BA61'])
    fig, ax = plt.subplots(figsize=(6.5/2.54, 4.5/2.54), dpi=600)
    heatmap = sns.heatmap(
        time_matrix,
        annot=False,
        fmt=".1f",
        alpha=0.7,
        cmap=custom_cmap,
        xticklabels=algorithms,
        yticklabels=[f"{d}" for d in dofs],
        annot_kws={"color": "black","size": 9},
        cbar_kws={"shrink": 0.9}  # 
    )
    heatmap.tick_params(axis='both', which='both', length=0)  #
    heatmap.collections[0].colorbar.ax.tick_params(labelsize=9)  
    plt.ylabel("DOF")
    plt.show()
    
    
    custom_cmap = LinearSegmentedColormap.from_list("colors", ['#7BABD2', '#B3C786'])
    # C/5
    fig, ax = plt.subplots(figsize=(6.5/2.54, 4.5/2.54), dpi=600)
    
    heatmap = sns.heatmap((fit_error_matrix_C5+fit_error_matrix_C40)/2, alpha=0.7,annot=False, fmt=".3f", cmap=custom_cmap,
                xticklabels=algorithms, yticklabels=[f"{d}" for d in dofs],
                annot_kws={"color": "black","size": 9},
                cbar_kws={"shrink": 0.9})
    # plt.title("Mean Fit Error (C/5)")
    # plt.xlabel("Algorithm")
    plt.ylabel("DOF")
    heatmap.tick_params(axis='both', which='both', length=0)  # 
    heatmap.collections[0].colorbar.ax.tick_params(labelsize=9)  
    # plt.tight_layout()
    plt.show()


#%%

# Cell: overlay the measured dV/dQ curve of one cell with the fitted dV/dQ
# curves obtained with different objective (loss) functions, all for the same
# optimiser (`opt_func_trial`) and number of fitted parameters (`dof`).
opt_func_trial = 'DE'
dof = 3
norminal_c = 4.84
# object_losses = ['mse', 'eucl', 'dvf', 'eucl_dvf']
colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']
object_losses = ['eucl','mse','dvf','eucl_mse','eucl_dvf']

linestyles = ['--','-.','--',':',':','--','-.']
fig, axs = plt.subplots(1, 1, figsize=(10/ 2.54, 6.5 / 2.54), dpi=600)
# opt_funcs_options= ['PSO', 'DE', 'GA', 'CMA-ES', 'BO']
for cell_index in [93]: #81
    
    # for loss_idx, opt_func_trial in enumerate(opt_funcs_options): 
    for loss_idx, object_loss in enumerate(object_losses):
        # One result archive per loss function.
        filename = f"saved_fittings/resval_extract_data_{opt_func_trial}_DOF{dof}_{object_loss}.npz"
        print(f"Loading: {filename}")
        data = np.load(filename, allow_pickle=True)
        
        all_fit_results = data['all_fit_results']     # shape: (188, 10, ?)
        all_OCV_fit = data['all_OCV_fit']             # shape: (188, N)
        # Cells with index > 93 take their reference curve from 'all_cell_vmea',
        # the others from 'all_cell_ocv'.
        if cell_index > 93:
            all_cell_ocv = data['all_cell_vmea']           # shape: (188, N, 2)
        else:
            all_cell_ocv = data['all_cell_ocv']  
        
        fitted_Voc = all_OCV_fit[cell_index]          
        # Column 1 is used as capacity, column 0 (rescaled by 4.2) as voltage.
        measure_Q = all_cell_ocv[cell_index, :, 1]
        measure_V = all_cell_ocv[cell_index, :, 0] * 4.2
        
        # --- Voltage curve ---
        # axs[0].plot(measure_Q*norminal_c, measure_V, 'grey',label='Measured' if loss_idx == 0 else "")
        # axs[0].plot(measure_Q*norminal_c, fitted_Voc, linestyle=linestyles[loss_idx],color=colors[loss_idx], label=opt_func_trial)
    
        # --- dV/dQ curve ---
        # NOTE(review): the derivative is taken with respect to `measure_Q`
        # (not multiplied by `norminal_c`), while the x-axis uses
        # `measure_Q*norminal_c` and the y-axis is labelled "[V/Ah]" -- confirm
        # which unit is intended.
        dv_fit = np.gradient(fitted_Voc, measure_Q)
        dv_meas = np.gradient(measure_V, measure_Q)
        # The measured curve is drawn only once (first loss); the inner
        # conditional on the label is therefore always true here.
        if loss_idx==0:
            
            axs.plot(measure_Q*norminal_c, -dv_meas, 'grey', alpha=0.8,label='Measured' if loss_idx == 0 else "")
        axs.plot(measure_Q*norminal_c, -dv_fit, linestyle=linestyles[loss_idx],color=colors[loss_idx], alpha=0.8,label=object_loss)


# axs[0].set_xlabel("Q [Ah]")
# axs[0].set_ylabel("Voltage [V]")
# axs[0].legend()
# axs[0].grid(True)


axs.set_xlabel("Q [Ah]")
axs.set_ylabel("dV/dQ [V/Ah]")
axs.legend(loc='best',
           bbox_to_anchor=(0.7, 0.35), 
          handletextpad=0.1, 
          labelspacing=0.05,
          frameon=False)
# axs.grid(True)
axs.set_ylim([0.0, 4])
axs.set_xlim([-0.1, 4.5])

plt.tight_layout()
plt.show()


#%%

# Cell: compare measured and fitted dV/dQ curves for two cells (labelled
# C/40 and C/5 in the legend) taken from the "_reverse" fitting results of a
# single optimiser / DOF / loss combination.
opt_func_trial = 'DE'
dof = 3
norminal_c = 4.84
# object_losses = ['mse', 'eucl', 'dvf', 'eucl_dvf']
colors = ['#DB6C6E', '#7BABD2', '#B3C786', '#B283B9','#F4BA61','#D5CA80','#9593C3']
# colors = [ '#7BABD2','#D5CA80']
object_loss = 'eucl'
labels = ['C/40 Measure','C/40 Fit','C/5 Measure','C/5 Fit', ]
fig, axs = plt.subplots(1, 1, figsize=(10/ 2.54, 6.5 / 2.54), dpi=600)

# Both cells live in the same archive, so it is opened (and the fitted curves
# decoded) once instead of once per plotted cell.
filename = f"saved_fittings/resval_extract_data_{opt_func_trial}_DOF{dof}_{object_loss}_reverse.npz"
print(f"Loading: {filename}")
data = np.load(filename, allow_pickle=True)

all_fit_results = data['all_fit_results']     # shape: (188, 10, ?)
all_OCV_fit = data['all_OCV_fit']             # shape: (188, N)

for pair_idx, cell_index in enumerate([93, 93 + 94]):
    # Each cell uses two consecutive entries of `labels` / `colors`:
    # one for the measured curve and one for the fitted curve.
    loss_idx = 2 * pair_idx

    # Cells with index > 93 take their reference curve from 'all_cell_vmea',
    # the others from 'all_cell_ocv'.
    if cell_index > 93:
        all_cell_ocv = data['all_cell_vmea']
    else:
        all_cell_ocv = data['all_cell_ocv']

    fitted_Voc = all_OCV_fit[cell_index]
    # Column 1 is used as capacity, column 0 (rescaled by 4.2) as voltage.
    measure_Q = all_cell_ocv[cell_index, :, 1]
    measure_V = all_cell_ocv[cell_index, :, 0] * 4.2

    # --- dV/dQ curve ---
    # NOTE(review): the derivative is taken with respect to `measure_Q`
    # (not multiplied by `norminal_c`), while the x-axis uses
    # `measure_Q*norminal_c` and the y-axis is labelled "[V/Ah]" -- confirm
    # which unit is intended.
    dv_fit = np.gradient(fitted_Voc, measure_Q)
    dv_meas = np.gradient(measure_V, measure_Q)
    axs.plot(measure_Q*norminal_c, -dv_meas, linestyle='-',color=colors[loss_idx], alpha=0.8,label=labels[loss_idx])
    axs.plot(measure_Q*norminal_c, -dv_fit, linestyle='--',color=colors[loss_idx+1], alpha=0.8,label=labels[loss_idx+1])

axs.set_xlabel("Q [Ah]")
axs.set_ylabel("dV/dQ [V/Ah]")
axs.legend(loc='best',
           bbox_to_anchor=(0.7, 0.50),
           handletextpad=0.1,
           labelspacing=0.05,
           frameon=False)
# axs.grid(True)
axs.set_ylim([0.0, 4])
axs.set_xlim([-0.1, 4.5])

plt.tight_layout()
plt.show()

#%%
# Cell: settings for the Jacobian full-rank check performed below.
norminal_c = 4.84
# Finite-difference step handed to approx_fprime in compute_jacobian.
epsilon = 1e-6
# Optimiser tag and DOF values used to build the result-file names.
opt_func = "DE"
dof_list = [2, 3, 4]

# Filled by the loop below: one full-rank ratio per entry of dof_list.
results_ratio = {}  # {rate: [ratio_dof2, ratio_dof3, ratio_dof4]}

def unpack_params(params, dof):
    """Expand a reduced parameter vector to the full ``[Cp, Cn, x0, y0]`` form.

    Parameters
    ----------
    params : array-like
        ``[Cp, Cn, x0, y0]`` for ``dof == 4``, ``[Cp, Cn, NP_offset]`` for
        ``dof == 3`` or ``[NP_ratio, NP_offset]`` for ``dof == 2``.
    dof : int
        Number of free parameters (2, 3 or 4).

    Returns
    -------
    numpy.ndarray
        Float64 array ``[Cp, Cn, x0, y0]``. For ``dof`` 2 and 3, ``y0`` is
        0.0 and ``x0`` is the offset; for ``dof`` 2, ``Cp`` is 1.0 and ``Cn``
        equals the ratio.

    Raises
    ------
    ValueError
        If ``dof`` is not 2, 3 or 4.
    """
    values = np.asarray(params, dtype=np.float64)

    # Full parameterisation: nothing to expand.
    if dof == 4:
        return values

    if dof == 3:
        Cp, Cn, x0 = values
    elif dof == 2:
        ratio, x0 = values
        Cp = 1.0
        Cn = Cp * ratio
    else:
        raise ValueError("Unsupported DOF")

    # y0 is pinned to zero in both reduced parameterisations.
    return np.array([Cp, Cn, x0, 0.0])

def model_Voc(all_params, Q, c_rate):
    """Evaluate the two-electrode voltage model ``Up - Un`` at capacity ``Q``.

    Parameters
    ----------
    all_params : sequence of four values
        ``(Cp, Cn, x0, y0)``.
    Q : float or array-like
        Capacity value(s) at which the model is evaluated.
    c_rate : str
        ``'C/40_Cycle'`` selects ``OCP_p_40`` / ``OCP_n_40``;
        ``'C/5_Cycle'`` selects ``OCP_p`` / ``OCP_n``.

    Returns
    -------
    Difference between the positive- and negative-electrode potential
    functions evaluated at ``y0 + Q / Cp`` and ``x0 - Q / Cn`` respectively.

    Raises
    ------
    ValueError
        If ``c_rate`` is neither of the two supported labels.
    """
    Cp, Cn, x0, y0 = all_params
    soc_pos = y0 + Q / Cp
    soc_neg = x0 - Q / Cn

    if c_rate == 'C/40_Cycle':
        return OCP_p_40(soc_pos) - OCP_n_40(soc_neg)
    if c_rate == 'C/5_Cycle':
        return OCP_p(soc_pos) - OCP_n(soc_neg)
    raise ValueError("Unsupported rate")

def residual(params_dof, Q, V, rate, dof):
    """Return ``V`` minus the model voltage for a reduced parameter vector.

    ``params_dof`` is first expanded with ``unpack_params(params_dof, dof)``
    and the model is evaluated with ``model_Voc`` at ``Q`` for the given
    ``rate`` label.
    """
    predicted = model_Voc(unpack_params(params_dof, dof), Q, rate)
    return V - predicted

def compute_jacobian(params_dof, Q, V, rate, dof):
    """Forward-difference Jacobian of ``residual`` w.r.t. the fitted parameters.

    The residual vector is evaluated once at ``params_dof`` and once per
    perturbed parameter (step ``epsilon``, module-level), i.e. ``len(params) + 1``
    evaluations in total. The previous implementation re-evaluated the full
    residual vector for every single row of the Jacobian.

    Parameters
    ----------
    params_dof : array-like
        Reduced parameter vector, as accepted by ``residual``.
    Q, V : array-like
        Capacity and voltage samples forwarded to ``residual``.
    rate : str
        C-rate label forwarded to ``residual``.
    dof : int
        Number of free parameters forwarded to ``residual``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(Q), len(params_dof))`` whose column ``k`` is the
        forward difference of the residual with respect to parameter ``k``.
    """
    base = np.asarray(params_dof, dtype=np.float64)
    r0 = np.asarray(residual(base, Q, V, rate, dof), dtype=np.float64)

    J = np.zeros((len(Q), base.size))
    for k in range(base.size):
        shifted = base.copy()
        shifted[k] += epsilon
        # Divide by the step actually realised in floating point.
        step = shifted[k] - base[k]
        J[:, k] = (residual(shifted, Q, V, rate, dof) - r0) / step
    return J

# 
# For every DOF and every C-rate label, evaluate the residual Jacobian at the
# stored optimum of each cell and record the fraction of cells whose Jacobian
# has full column rank (rank == DOF).
for DOF in dof_list:
    # NOTE: `object_loss` is not assigned in this cell; the value left behind
    # by an earlier cell is used.
    filename = f"saved_fittings/resval_extract_data_{opt_func}_DOF{DOF}_{object_loss}.npz"
    data = np.load(filename, allow_pickle=True)

    # Parameters are kept exactly as stored. Q below is used unscaled as well,
    # and the DOF=2 parameterisation fixes Cp = 1.0, so scaling Cp/Cn by the
    # nominal capacity (as done previously) put parameters and Q on different
    # scales. NOTE(review): assumes the stored optimum was fitted against the
    # unscaled Q column -- confirm against the fitting code.
    Cp_all = data['all_Cp_opt']
    Cn_all = data['all_Cn_opt']
    x0_all = data['all_x0_opt']
    y0_all = data['all_y0_opt']
    cells = data['all_cells']
    vmea = data['all_cell_vmea']

    rate_labels = cells[:, 1]
    unique_rates = np.unique(rate_labels)

    for rate in unique_rates:
        idxs = np.where(rate_labels == rate)[0]
        ranks = []
        print(f"\n⏱️ Checking {rate} (DOF={DOF})...")
        for i in idxs:
            # Column 0 is the voltage (rescaled by 4.2) and column 1 the
            # capacity, consistent with every other use of 'all_cell_vmea'
            # in this file; the two columns were swapped here before.
            V = vmea[i][:, 0] * 4.2
            Q = vmea[i][:, 1]

            if DOF == 4:
                params_dof = [Cp_all[i], Cn_all[i], x0_all[i], y0_all[i]]
            elif DOF == 3:
                params_dof = [Cp_all[i], Cn_all[i], x0_all[i]]
            elif DOF == 2:
                Cp = 1.0
                Cn = Cn_all[i]
                NP_ratio = Cn / Cp
                params_dof = [NP_ratio, x0_all[i]]
            else:
                raise ValueError("Unsupported DOF")

            # Best effort: a failing cell is reported and counted as
            # not full rank (rank = -1) instead of aborting the scan.
            try:
                J = compute_jacobian(params_dof, Q, V, rate, DOF)
                rank = matrix_rank(J, tol=1e-6)
            except Exception as e:
                print(f"❌ Error: {e}")
                rank = -1
            ranks.append(rank)

        ranks = np.array(ranks)
        total = len(ranks)
        full_rank_count = np.sum(ranks == DOF)
        full_rank_ratio = full_rank_count / total
        print(f"✔ Full rank count: {full_rank_count} / {total}")
        results_ratio.setdefault(rate, []).append(full_rank_ratio)

# Rows: C-rate labels, columns: DOF values (in dof_list order).
rate_names = list(results_ratio.keys())
heatmap_data = np.array([results_ratio[rate] for rate in rate_names])

plt.figure(figsize=(6, 3), dpi=600)
sns.heatmap(heatmap_data, annot=True, fmt=".2f", cmap="YlGnBu",
            xticklabels=[f"DOF={d}" for d in dof_list],
            yticklabels=rate_names,
            cbar_kws={'label': 'Full Rank Ratio'})
plt.title("Jacobian Full-Rank Ratio for Different DOFs")
plt.tight_layout()
plt.show()

#%%
# Cell: load the DE / DOF=3 / 'eucl' fitting results and derive per-point
# voltage and capacity differences.
filename = f"saved_fittings/resval_extract_data_DE_DOF3_eucl.npz"
print(filename)
norminal_c = 4.84
data = np.load(filename, allow_pickle=True)
all_Cq = data['all_Cq']
# Stored capacities are multiplied by the nominal capacity here.
all_Cp_opt = data['all_Cp_opt']*norminal_c
all_Cn_opt = data['all_Cn_opt']*norminal_c
all_x0_opt = data['all_x0_opt']
all_y0_opt = data['all_y0_opt']
all_OCV_fit = data['all_OCV_fit']
all_cell_cap = data['all_cell_cap']
all_cell_ocv = data['all_cell_ocv']
all_cell_vmea = data['all_cell_vmea']
all_cells = data['all_cells']
all_fit_results = data['all_fit_results']
# NOTE(review): `time_consum_all_optimization` is not created in this cell, so
# this line relies on a list left over from an earlier cell -- confirm it is
# still wanted here.
time_consum_all_optimization.append(data['time_consum']/60)
# Last axis of the curve arrays: index 0 -> voltage (rescaled by 4.2),
# index 1 -> capacity; for 'all_cell_vmea' index 2 is stored as "Crate".
all_cell_Vreal = all_cell_ocv[:,:,0]*4.2
all_cell_Qreal = all_cell_ocv[:,:,1]
all_cell_Vm = all_cell_vmea[:,:,0]*4.2
all_cell_Qm = all_cell_vmea[:,:,1]
all_cell_Crate = all_cell_vmea[:,:,2]
# Differences between the 'all_cell_ocv' curve and the fitted voltage, and
# between the two capacity axes (both scaled by the nominal capacity).
all_v_diff = all_cell_Vreal-all_OCV_fit
all_q_diff = all_cell_Qreal*norminal_c-all_cell_Qm*norminal_c


# Regroup every array of the archive by the C-rate label of its cell.
data_dict = {name: data[name] for name in data.files}
# print("Keys in the file:", data.files)
all_cells = data['all_cells']
cell_names, rate_labels = all_cells[:, 0], all_cells[:, 1]
unique_labels = np.unique(rate_labels)

split_data_dict = {label: {} for label in unique_labels}

for key in data_dict:
    if key == 'time_consum':
        continue
    values = data_dict[key]
    # Arrays with one row per cell are filtered by label; anything else is
    # shared unchanged between all labels.
    per_cell = values.shape[0] == len(rate_labels)
    for label in unique_labels:
        split_data_dict[label][key] = values[rate_labels == label] if per_cell else values


# for label in unique_labels:
#     print(f"Data for rate {label}: keys -> {split_data_dict[label].keys()}")

data_sets = []
# ['C/40', 'C/5']
assert len(unique_labels) == 2, "C/5 and C/40, need to change to suit more C-rates"

label1, label2 = unique_labels
subset1, subset2 = split_data_dict[label1], split_data_dict[label2]

# Each entry: (values for label1, values for label2, name for label1, name for label2).
data_sets = [
    tuple(part[f'all_{name}_opt']*norminal_c for part in (subset1, subset2))
    + (f'{name} {label1}', f'{name} {label2}')
    for name in ('Cp', 'Cn')
]
data_sets.append(
    tuple(
        part['all_Cn_opt'] * part['all_x0_opt']*norminal_c + part['all_Cp_opt'] * part['all_y0_opt']*norminal_c
        for part in (subset1, subset2)
    )
    + (f'Cli {label1}', f'Cli {label2}')
)

def _collect_cell_samples(subset, cell_indices, target_rate, capacity_ah):
    """Build per-point ML samples and per-cell curves for the given cells.

    Parameters
    ----------
    subset : dict
        Arrays of one C-rate group (an entry of ``split_data_dict``).
    cell_indices : iterable of int
        Row indices (cells) of ``subset`` to process.
    target_rate : str
        ``'C/40_Cycle'`` takes the reference curve from ``'all_cell_ocv'``,
        ``'C/5_Cycle'`` from ``'all_cell_vmea'``.
    capacity_ah : float
        Factor applied to the capacity columns of the per-cell curves.

    Returns
    -------
    dict
        ``'inputs'`` / ``'outputs'``: lists with one 2-D block per kept cell
        (7 and 2 columns); ``'Q_real'``, ``'V_real'``, ``'Q_mea'``,
        ``'V_fit'``, ``'V_mea'``: lists with one 1-D curve per kept cell.

    Raises
    ------
    ValueError
        If ``target_rate`` is not one of the two supported labels.
    """
    if target_rate == 'C/40_Cycle':
        reference = subset['all_cell_ocv']
    elif target_rate == 'C/5_Cycle':
        reference = subset['all_cell_vmea']
    else:
        raise ValueError(f"Unsupported output C-rate: {target_rate!r}")

    vmea = subset['all_cell_vmea']
    ocv_fit = subset['all_OCV_fit']
    n_points = ocv_fit.shape[1]
    samples = {name: [] for name in
               ('inputs', 'outputs', 'Q_real', 'V_real', 'Q_mea', 'V_fit', 'V_mea')}

    for i in cell_indices:
        # Cells whose last measured capacity value is below 0.1 are dropped.
        if vmea[i, -1, 1] < 0.1:
            print('skip meausre Q less than 0.1')
            continue

        cp = subset['all_Cp_opt'][i]
        cn = subset['all_Cn_opt'][i]
        cli = cn * subset['all_x0_opt'][i] + cp * subset['all_y0_opt'][i]
        v_fit_norm = ocv_fit[i, :] / 4.2
        v_mea = vmea[i, :n_points, 0]
        q_mea = vmea[i, :n_points, 1]
        c_rate = -1 * vmea[i, :n_points, 2]

        # Feature order: Cp, Cn, Cli, V_fit, V_mea, Q_mea, Crate.
        samples['inputs'].append(np.column_stack([
            np.full(n_points, cp), np.full(n_points, cn), np.full(n_points, cli),
            v_fit_norm, v_mea, q_mea, c_rate,
        ]).astype(np.float64, copy=False))
        # Targets: reference voltage minus fitted voltage, and reference
        # capacity minus measured capacity (identically zero when the
        # reference is the measured curve itself).
        samples['outputs'].append(np.column_stack([
            reference[i, :n_points, 0] - v_fit_norm,
            reference[i, :n_points, 1] - q_mea,
        ]).astype(np.float64, copy=False))

        samples['Q_real'].append(reference[i, :, 1] * capacity_ah)
        samples['V_real'].append(reference[i, :, 0] * 4.2)
        samples['Q_mea'].append(vmea[i, :, 1] * capacity_ah)
        samples['V_fit'].append(ocv_fit[i, :])
        samples['V_mea'].append(vmea[i, :, 0] * 4.2)
    return samples


def _stack_sample_blocks(blocks, n_columns):
    """Stack per-cell blocks row-wise; return an empty (0, n_columns) array if none."""
    if not blocks:
        return np.empty((0, n_columns))
    return np.vstack(blocks)


input_ml_train = []
output_ml_train = []
input_ml_test = []
output_ml_test = []
Q_real_train = []
V_real_train = []
Q_real_test = []
V_real_test = []
Q_mea_train = []
Q_mea_test = []
V_fit_train = []
V_fit_test = []
V_mea_train = []
V_mea_test = []

out_put_c_rate = 'C/5_Cycle'

for label in ['C/5_Cycle']:   #unique_labels  ['C/40_Cycle', 'C/5_Cycle']
    subset = split_data_dict[label]

    # Even-indexed cells are used for training, odd-indexed cells for testing.
    train_cells = np.arange(0, len(subset['all_Cp_opt']), 2)
    test_cells = np.arange(1, len(subset['all_Cp_opt']), 2)

    train_samples = _collect_cell_samples(subset, train_cells, out_put_c_rate, nominal_capacity)
    input_ml_train.extend(train_samples['inputs'])
    output_ml_train.extend(train_samples['outputs'])
    Q_real_train.extend(train_samples['Q_real'])
    V_real_train.extend(train_samples['V_real'])
    Q_mea_train.extend(train_samples['Q_mea'])
    V_fit_train.extend(train_samples['V_fit'])
    V_mea_train.extend(train_samples['V_mea'])

    test_samples = _collect_cell_samples(subset, test_cells, out_put_c_rate, nominal_capacity)
    input_ml_test.extend(test_samples['inputs'])
    output_ml_test.extend(test_samples['outputs'])
    Q_real_test.extend(test_samples['Q_real'])
    V_real_test.extend(test_samples['V_real'])
    Q_mea_test.extend(test_samples['Q_mea'])
    V_fit_test.extend(test_samples['V_fit'])
    V_mea_test.extend(test_samples['V_mea'])

# One row per (cell, point): 7 input features and 2 targets.
input_ml_train = _stack_sample_blocks(input_ml_train, 7)
output_ml_train = _stack_sample_blocks(output_ml_train, 2)
input_ml_test = _stack_sample_blocks(input_ml_test, 7)
output_ml_test = _stack_sample_blocks(output_ml_test, 2)

# Absolute correlation between all input features and both targets,
# computed over the train and test samples together.
input_ml_all = np.vstack([input_ml_train, input_ml_test])
output_ml_all = np.vstack([output_ml_train, output_ml_test])
combined_data = np.hstack([input_ml_all, output_ml_all])

input_columns = ['Cp', 'Cn', 'Cli', 'V_fit', 'V_mea', 'Q_mea', 'Crate']
output_columns = ['V_diff', 'Q_diff']
columns = [*input_columns, *output_columns]

combined_df = pd.DataFrame(combined_data, columns=columns)
correlation_matrix = combined_df.corr().abs()

# Heatmap of the absolute correlations on a fixed 0..1 colour scale.
plt.figure(figsize=(15 / 2.54, 10 / 2.54), dpi=600)
plt.ion()
plt.rcParams.update({'xtick.direction': 'in', 'ytick.direction': 'in'})
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
sns.heatmap(
    correlation_matrix,
    annot=True,
    cmap='coolwarm',
    fmt='.2f',
    linewidths=0.5,
    vmin=0,
    vmax=1,
)
plt.title("Correlation Heatmap")
plt.show()

# Tune an XGBoost regressor on the training samples with an exhaustive grid
# search and predict both targets for the test samples.
xgb_model = XGBRegressor(random_state=123)

# Define the parameter grid
param_grid = {
    'n_estimators': [50, 100, 200],  # Number of trees
    'max_depth': [5, 10, 20],  # Tree depth
    'learning_rate': [0.01, 0.05, 0.1],  # Step size
    # 'subsample': [0.8, 1.0],  # Fraction of samples per tree
    # 'colsample_bytree': [0.8, 1.0],  # Fraction of features per tree
    # 'gamma': [0, 0.1, 0.2],  # Minimum loss reduction required for a split
    # 'reg_alpha': [0, 0.01, 0.1],  # L1 regularization
    # 'reg_lambda': [0.1, 1, 10]  # L2 regularization
}

# Perform Grid Search with cross-validation
grid_search = GridSearchCV(
    estimator=xgb_model,
    param_grid=param_grid,
    scoring='neg_root_mean_squared_error',  # or r2
    cv=5,  # 5-fold cross-validation
    n_jobs=-1,  # Use all processors
    verbose=1,  # Display progress
)

# Fit GridSearchCV
grid_search.fit(input_ml_train, output_ml_train)

# Print the best parameters
print("Best Parameters:", grid_search.best_params_)

# GridSearchCV refits the best configuration on the full training set by
# default (refit=True), so best_estimator_ is ready to use; fitting it a
# second time only repeated that work.
xgb_model_best = grid_search.best_estimator_
y_pred = xgb_model_best.predict(input_ml_test)


# Test-set errors of the predicted residuals, next to the error of a trivial
# all-zero prediction. Targets are multiplied by 4.2 / norminal_c, the same
# factors used for the "[V]" / "[Ah]" axes of the parity plots.
# The MSE values are kept in mse_*; the printed figures are RMSE values so
# that the "[mV]" / "[mAh]" units are dimensionally correct (MSE * 1000 is not
# a quantity in mV or mAh).
v_true = output_ml_test[:, 0].reshape(-1) * 4.2
v_pred = y_pred[:, 0].reshape(-1) * 4.2
mse_v = mean_squared_error(v_true, v_pred)
print("Root Mean Squared Error of Predicted Voltage Residuals [mV]:", np.sqrt(mse_v) * 1000)
mse_v_m = mean_squared_error(v_true, 0 * v_true)
print("Root Mean Squared Error of Measured Voltage Residuals [mV]:", np.sqrt(mse_v_m) * 1000)

q_true = output_ml_test[:, 1].reshape(-1) * norminal_c
q_pred = y_pred[:, 1].reshape(-1) * norminal_c
mse_q = mean_squared_error(q_true, q_pred)
print("Root Mean Squared Error of Predicted Capacity Residuals [mAh]:", np.sqrt(mse_q) * 1000)
# Stored under its own name (mirroring mse_v_m) so that it no longer
# overwrites the model error in mse_q.
mse_q_m = mean_squared_error(q_true, 0 * q_pred)
print("Root Mean Squared Error of Measured Capacity Residuals [mAh]:", np.sqrt(mse_q_m) * 1000)

# Parity plots (true vs. predicted) of the voltage and capacity residuals on
# the test set; points are coloured by their absolute prediction error.
# NOTE(review): `cmap` built here is not passed to the scatter calls below,
# which use the named 'coolwarm' colormap instead -- confirm which is intended.
cmap = mcolors.LinearSegmentedColormap.from_list('custom_cmap',['#E59693','#0073B1'])
# Absolute voltage error, then min-max normalised for colouring.
orig_color_values = np.abs(output_ml_test[:,0].reshape(-1)*4.2-y_pred[:,0].reshape(-1)*4.2)
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
color_values = norm(orig_color_values)

plt.figure(figsize=(20 / 2.54, 8 / 2.54), dpi=600)
plt.ion()
# Left panel: voltage residuals.
plt.subplot(121)
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
# plt.scatter(output_ml_test[:,0].reshape(-1)*4.2,y_pred[:,0].reshape(-1)*4.2)
scatter1 = plt.scatter(output_ml_test[:,0].reshape(-1)*4.2, y_pred[:,0].reshape(-1)*4.2, 
                      c=color_values, alpha=0.9, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
# Identity line (perfect prediction).
plt.plot(output_ml_test[:,0].reshape(-1)*4.2,output_ml_test[:,0].reshape(-1)*4.2,color='grey',linewidth=2)
plt.xlabel('Real values of voltage residuals [V]')
plt.ylabel('Predictions of voltage residuals [V]')
cbar = plt.colorbar()
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
# cbar.set_label('Normalized Color values')
# The colour data are normalised, so ticks are placed at the normalised
# positions and labelled with the corresponding un-normalised errors.
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(norm(ticks))
cbar.set_ticklabels(tick_labels)

# Right panel: capacity residuals, same construction as the left panel.
orig_color_values = np.abs(output_ml_test[:,1].reshape(-1)*norminal_c-y_pred[:,1].reshape(-1)*norminal_c)
norm = Normalize(vmin=orig_color_values.min(), vmax=orig_color_values.max())
color_values = norm(orig_color_values)
plt.subplot(122)
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.tick_params(top='on', right='on', which='both')
plt.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)
# plt.scatter(output_ml_test[:,1].reshape(-1)*norminal_c,y_pred[:,1].reshape(-1)*norminal_c)
scatter2 = plt.scatter(output_ml_test[:,1].reshape(-1)*norminal_c, y_pred[:,1].reshape(-1)*norminal_c, 
                      c=color_values, alpha=0.9, cmap='coolwarm', marker='o', linewidth=0.0, s=10, edgecolors=None)
# Identity line (perfect prediction).
plt.plot(output_ml_test[:,1].reshape(-1)*norminal_c,output_ml_test[:,1].reshape(-1)*norminal_c,color='grey',linewidth=2)
plt.xlabel('Real values of capacity residuals [Ah]')
plt.ylabel('Predictions of capacity residuals [Ah]')
cbar = plt.colorbar()
cbar.set_label('Absolute error')
plt.tick_params(bottom=False, left=False)
# cbar.set_label('Normalized Color values')
ticks = np.linspace(orig_color_values.min(), orig_color_values.max(), num=3)
tick_labels = ["{:.2f}".format(value) for value in ticks]
cbar.set_ticks(norm(ticks))
cbar.set_ticklabels(tick_labels)

plt.tight_layout()
plt.show()
